
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
import numpy as np
import tensorflow as tf
from tensorflow import losses


def compute_iou(x1,y1,w1,h1, x2,y2,w2,h2):
    """Return the element-wise IoU of two sets of centre-format boxes.

    Box ``i`` of the first set is ``(x1, y1, w1, h1)[i]`` and is compared
    only with box ``i`` of the second set; nothing is compared across
    positions. All eight arguments must be broadcast-compatible (the caller
    in this file passes tensors annotated as ``[b,16,16,5]``).

    Args:
        x1, y1: centre coordinates of the first boxes.
        w1, h1: width and height of the first boxes.
        x2, y2: centre coordinates of the second boxes.
        w2, h2: width and height of the second boxes.

    Returns:
        ``intersection / (union + 1e-6)`` with the broadcast shape of the
        inputs. Boxes that do not overlap yield exactly 0.
    """
    # Convert centre/size to corner coordinates.
    xmin1 = x1 - 0.5 * w1
    xmax1 = x1 + 0.5 * w1
    ymin1 = y1 - 0.5 * h1
    ymax1 = y1 + 0.5 * h1

    xmin2 = x2 - 0.5 * w2
    xmax2 = x2 + 0.5 * w2
    ymin2 = y2 - 0.5 * h2
    ymax2 = y2 + 0.5 * h2

    # Overlap extent along each axis. A negative raw extent means the boxes
    # are disjoint on that axis, so it is clamped to zero; without the clamp
    # two negative extents would multiply into a bogus positive area.
    interw = np.maximum(np.minimum(xmax1, xmax2) - np.maximum(xmin1, xmin2), 0.0)
    interh = np.maximum(np.minimum(ymax1, ymax2) - np.maximum(ymin1, ymin2), 0.0)
    inter = interw * interh

    union = w1 * h1 + w2 * h2 - inter
    # The small epsilon guards against division by zero for degenerate boxes.
    iou = inter / (union + 1e-6)
    return iou


def yolo_loss(detector_mask, matching_gt_boxes, matching_classes_oh, gt_boxes_grid, y_pred):
    """Compute the YOLO-style detection loss for one batch.

    The total is ``coord_loss + class_loss + nonobj_loss + 5 * obj_loss``.

    The shapes listed below come from the original author's notes and match
    the hard-coded 16x16 grid with 5 anchors; they are not validated here.
    Ground-truth coordinates are presumably expressed in grid-cell units,
    like the decoded predictions -- confirm against the data loader.

    Args:
        detector_mask: [b,16,16,5,1]; values > 0 mark the anchors that take
            part in the coordinate, class and object terms.
        matching_gt_boxes: [b,16,16,5,5], x-y-w-h-label of the ground-truth
            box matched to each anchor.
        matching_classes_oh: [b,16,16,5,2], one-hot class of the matched box.
        gt_boxes_grid: [b,40,5], x-y-w-h-label of the ground-truth boxes of
            each image.
        y_pred: [b,16,16,5,7], raw network output x-y-w-h-conf-l1-l2.

    Returns:
        A tuple ``(loss, [nonobj_loss + 5 * obj_loss, class_loss,
        coord_loss])`` of scalar tensors.
    """
    ANCHORS = [0.57273, 0.677385, 1.87446, 2.06253, 3.33843, 5.47434, 7.88282, 3.52778, 9.77052, 9.16828]
    GRIDSZ = 16

    # [5,2]: one (w, h) pair per anchor.
    anchors = np.array(ANCHORS).reshape(5, 2)

    # 1. Offset of every grid cell.
    # [256]: 0..15 repeated 16 times
    x_grid = tf.tile(tf.range(GRIDSZ), [GRIDSZ])
    # [1,16,16,1,1]: value at [0, i, j] is the column index j
    x_grid = tf.reshape(x_grid, (1, GRIDSZ, GRIDSZ, 1, 1))
    x_grid = tf.cast(x_grid, tf.float32)
    # Swapping the two grid axes turns the column index into the row index.
    y_grid = tf.transpose(x_grid, (0, 2, 1, 3, 4))
    # [1,16,16,1,2]; left un-tiled so it broadcasts over batch and anchors
    # and does not require a statically known batch size.
    xy_grid = tf.concat([x_grid, y_grid], axis=-1)

    # 2. Decode the raw predictions.
    # [b,16,16,5,2]: sigmoid offset inside the cell + cell position
    pred_xy = tf.sigmoid(y_pred[..., 0:2])
    pred_xy = pred_xy + xy_grid
    # [b,16,16,5,2] * [5,2] => [b,16,16,5,2]: exp scale times anchor size
    pred_wh = tf.exp(y_pred[..., 2:4])
    pred_wh = pred_wh * anchors

    # Number of active anchors, used to average the masked terms.
    n_detector_mask = tf.reduce_sum(tf.cast(detector_mask > 0., tf.float32))

    # 3. Coordinate loss.
    # [b,16,16,5,1] * [b,16,16,5,2]
    xy_loss = detector_mask * tf.square(matching_gt_boxes[..., :2] - pred_xy) / (n_detector_mask + 1e-6)
    xy_loss = tf.reduce_sum(xy_loss)
    # Width/height are compared in square-root space.
    wh_loss = detector_mask * tf.square(tf.sqrt(matching_gt_boxes[..., 2:4]) -
                                        tf.sqrt(pred_wh)) / (n_detector_mask + 1e-6)
    wh_loss = tf.reduce_sum(wh_loss)

    # 4.1 coordinate loss
    coord_loss = xy_loss + wh_loss

    # 4.2 class loss
    # [b,16,16,5,2] class logits
    pred_box_class = y_pred[..., 5:]
    # [b,16,16,5] class index recovered from the one-hot encoding
    true_box_class = tf.argmax(matching_classes_oh, -1)
    # [b,16,16,5] vs [b,16,16,5,2] => [b,16,16,5]
    class_loss = losses.sparse_categorical_crossentropy(
        true_box_class, pred_box_class, from_logits=True)
    # [b,16,16,5] => [b,16,16,5,1] * [b,16,16,5,1]
    class_loss = tf.expand_dims(class_loss, -1) * detector_mask
    class_loss = tf.reduce_sum(class_loss) / (n_detector_mask + 1e-6)

    # 4.3 object loss
    # IoU between each anchor's matched ground-truth box and its prediction;
    # this is the confidence target of the object term.
    # [b,16,16,5]
    x1, y1, w1, h1 = matching_gt_boxes[..., 0], matching_gt_boxes[..., 1], \
                     matching_gt_boxes[..., 2], matching_gt_boxes[..., 3]
    # [b,16,16,5]
    x2, y2, w2, h2 = pred_xy[..., 0], pred_xy[..., 1], pred_wh[..., 0], pred_wh[..., 1]
    ious = compute_iou(x1, y1, w1, h1, x2, y2, w2, h2)
    # [b,16,16,5,1]
    ious = tf.expand_dims(ious, axis=-1)

    # [b,16,16,5,1]
    pred_conf = tf.sigmoid(y_pred[..., 4:5])

    # IoU of every predicted box against every ground-truth box of the image.
    # [b,16,16,5,2] => [b,16,16,5,1,2]
    pred_xy = tf.expand_dims(pred_xy, axis=4)
    # [b,16,16,5,2] => [b,16,16,5,1,2]
    pred_wh = tf.expand_dims(pred_wh, axis=4)
    pred_wh_half = pred_wh / 2.
    pred_xymin = pred_xy - pred_wh_half
    pred_xymax = pred_xy + pred_wh_half

    # [b,40,5] => [b,1,1,1,40,5]; -1 lets the batch size be inferred.
    true_boxes_grid = tf.reshape(
        gt_boxes_grid,
        [-1, 1, 1, 1, gt_boxes_grid.shape[1], gt_boxes_grid.shape[2]])
    true_xy = true_boxes_grid[..., 0:2]
    true_wh = true_boxes_grid[..., 2:4]
    true_wh_half = true_wh / 2.
    true_xymin = true_xy - true_wh_half
    true_xymax = true_xy + true_wh_half
    # [b,16,16,5,1,2] vs [b,1,1,1,40,2] => [b,16,16,5,40,2]
    intersectxymin = tf.maximum(pred_xymin, true_xymin)
    # [b,16,16,5,1,2] vs [b,1,1,1,40,2] => [b,16,16,5,40,2]
    intersectxymax = tf.minimum(pred_xymax, true_xymax)
    # [b,16,16,5,40,2]; clamped so disjoint boxes contribute no area
    intersect_wh = tf.maximum(intersectxymax - intersectxymin, 0.)
    # [b,16,16,5,40] * [b,16,16,5,40] => [b,16,16,5,40]
    intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]
    # [b,16,16,5,1]
    pred_area = pred_wh[..., 0] * pred_wh[..., 1]
    # [b,1,1,1,40]
    true_area = true_wh[..., 0] * true_wh[..., 1]
    # [b,16,16,5,1] + [b,1,1,1,40] - [b,16,16,5,40] => [b,16,16,5,40]
    union_area = pred_area + true_area - intersect_area
    # [b,16,16,5,40]
    iou_score = intersect_area / union_area
    # [b,16,16,5]: best overlap with any ground-truth box
    best_iou = tf.reduce_max(iou_score, axis=4)
    # [b,16,16,5,1]
    best_iou = tf.expand_dims(best_iou, axis=-1)

    # An anchor counts as "no object" when its best IoU is below 0.6 and the
    # detector mask does not select it.
    nonobj_detection = tf.cast(best_iou < 0.6, tf.float32)
    nonobj_mask = nonobj_detection * (1 - detector_mask)
    # Number of no-object anchors, used to average the term.
    n_nonobj = tf.reduce_sum(tf.cast(nonobj_mask > 0., tf.float32))

    # No-object anchors are pushed towards confidence 0 ...
    nonobj_loss = tf.reduce_sum(nonobj_mask * tf.square(-pred_conf)) \
                  / (n_nonobj + 1e-6)
    # ... and selected anchors towards the IoU with their matched box.
    obj_loss = tf.reduce_sum(detector_mask * tf.square(ious - pred_conf)) \
               / (n_detector_mask + 1e-6)

    loss = coord_loss + class_loss + nonobj_loss + 5 * obj_loss

    return loss, [nonobj_loss + 5 * obj_loss, class_loss, coord_loss]


if __name__ == '__main__':
    # Smoke test: run the loss once on the first sample of the first batch.
    from Model import makeDarkNet
    from AncestralDataLoader import groundTruthGenerator, makeDataSet

    def _add_batch_axis(value):
        """Prepend a length-1 batch axis to ``value``."""
        return tf.expand_dims(value, axis=0)

    network = makeDarkNet()
    dataset = makeDataSet()
    batches = groundTruthGenerator(dataset)

    # Keep only the first sample of every array in the batch.
    image, mask, boxes, classes_oh, boxes_grid = [part[0] for part in next(batches)]

    # The model expects a batch axis; drop it again from its output.
    prediction = network(_add_batch_axis(image))[0]

    # yolo_loss works on batched inputs, so every argument gets a batch axis.
    targets = [_add_batch_axis(item) for item in (mask, boxes, classes_oh, boxes_grid)]
    loss, sub_loss = yolo_loss(*targets, _add_batch_axis(prediction))

    print(loss)
    print(sub_loss)